#####
# Script to generate figure 7.
#####

library(ggplot2)
library(cowplot)
library(viridis)
library(RColorBrewer)
library(patchwork)
library(tidyverse)
library(data.table)

library(pracma)
library(nleqslv)
library(rootSolve)
library(fields)
library(chebpol)
library(scales)

library(doParallel)
library(Rcpp)

library(simFnsEqns)

source("equations.r")
source("simulation_functions.r")
source("planner_simulation.r")

# ---- Model parameters ----
# Scalars are defined first and then collected into the one-row `params`
# data frame that the solver and simulation helpers receive.
avg_sat_decay <- 1 # 1: satellites are infinitely lived
p <- 1
F <- 35
discount <- 0.99
r <- (1 - discount) / discount
fe_eqm <- p / F - r # the equilibrium risk
d <- 0.45
m <- 2
aSS <- 0.005
aSD <- 0.005
aDD <- 0.025
bSS <- 10
bSD <- 5.375
bDD <- 0.9
Xbar <- 100
T <- 150
params <- data.frame(
  p = p, F = F, discount = discount, r = r, fe_eqm = fe_eqm,
  d = d, m = m,
  aSS = aSS, aSD = aSD, aDD = aDD,
  bSS = bSS, bSD = bSD, bDD = bDD,
  T = T, Xbar = Xbar
)

# Copy of `params` with the equilibrium risk doubled.
params2 <- params
params2[["fe_eqm"]] <- params[["fe_eqm"]] * 2

# ---- State-space grids ----
# Fine grid: 200 x 200 nodes covering [0, 4] x [0, 4]; the first column
# ("satellites") varies fastest, as produced by expand.grid().
S_element <- seq(0, 4, length.out = 200)
D_element <- seq(0, 4, length.out = 200)
SD_grid <- setNames(expand.grid(S_element, D_element),
                    c("satellites", "debris"))

# Coarse grid: 20 x 20 nodes on the same square, with columns "S" and "D".
S_element_s <- seq(0, 4, length.out = 20)
D_element_s <- seq(0, 4, length.out = 20)
SD_grid_s <- setNames(expand.grid(S_element_s, D_element_s), c("S", "D"))

# Axis vectors of the fine grid as a named list (S, D).
grid_list <- list(S = S_element, D = D_element)

# ---- Expanded grid ----
# Return the nodes of an axis expanded to cover [0, upper].
# With M = length(axis) and K = upper / max(axis), the expanded axis has
# K*M - (K - 1) nodes. When K is an integer this keeps the original spacing,
# so every node of `axis` is also a node of the result. For the grids above,
# max(S_element) = max(D_element) = 4, so K = 1 and the result equals `axis`.
expand_axis_nodes <- function(axis, upper = 4) {
  n_nodes <- length(axis)
  scale_k <- upper / max(axis)
  seq(0, upper, length.out = (scale_k * n_nodes - (scale_k - 1)))
}

message("Expanding grid area...")
S_element_big <- expand_axis_nodes(S_element)
D_element_big <- expand_axis_nodes(D_element)

# M and K remain defined as globals, holding the values for the D dimension.
M <- length(D_element)
K <- 4 / max(D_element)

# Create the expanded grid from the expanded axes. Unlike expand.grid(),
# tidyr::expand_grid() returns a tibble whose first column varies slowest.
SD_grid_big <- expand_grid(S_element_big, D_element_big)

# The "expanded grid" code calculates the scaling factor K as the ratio of the expanded grid's upper bound (4) to the maximum value of the corresponding axis (S_element or D_element). It then uses K to choose the number of points, K*M - (K-1), needed to cover [0, 4] in each dimension of the expanded grid (S_element_big and D_element_big) while keeping the original node spacing. Finally, it creates the expanded grid with tidyr's expand_grid() function from the new sets of points in each dimension. With the current grids K = 1, so S_element_big and D_element_big equal S_element and D_element.

# This code works generally as long as the scaling factor K is an integer: the expanded axes then keep the original spacing, so every node of the original grid is also a node of the expanded grid. If K is not an integer, seq() still spans [0, 4], but the spacing changes and the original nodes are generally not nodes of the expanded grid. The S and D dimensions are expanded independently, so if their ranges differ, each must give an integer K for the nodes to line up.
#####

# ---- Solve the planner's dynamic programming problem ----
# Build initial guesses on the coarse grid, solve the DP problem with the
# parallel backend, and round-trip both results through CSV files under
# ../../data. The worker cluster is stopped in `finally`, so an error in
# guess-building or solving does not leave `ncores` worker processes running.
solver_time <- proc.time()[3]
ncores <- 80
message("Initializing workers...")
vfi_cl <- makeCluster(ncores)
registerDoParallel(vfi_cl)
tryCatch({
  message("Generating guesses...")
  guess_dfrm <- guess_builder(params, grid_dfrm = as.data.frame(SD_grid_s),
                              grid_dfrm_big = as.data.frame(SD_grid), T = 150,
                              label = "plan-overshooting", policy_iteration = 1,
                              PFI_T = 10, fig4 = TRUE)
  write_csv(guess_dfrm, "../../data/VFIguess_plan-overshooting.csv")
  guess_dfrm <- read_csv("../../data/VFIguess_plan-overshooting.csv")
  # output <- guess_dfrm
  message("Solving DP problems...")
  output <- dp_solver(params, guess_dfrm, grid_list)
}, finally = stopCluster(vfi_cl))
solver_time <- proc.time()[3] - solver_time
message("Time to solve planner problems: ", round(solver_time / 60, 3), " minutes.")
write_csv(output, "../../data/VFI-soln_plan-overshooting.csv")
output <- read_csv("../../data/VFI-soln_plan-overshooting.csv")

# ---- Interpolate the solved launch policy ----
# Point grid_list at the expanded axes and attach the solution to every node
# of the expanded grid (nodes without a solution get X = 0). Then build a
# multilinear interpolant of the solved launch rate output$X.
grid_list <- list(S = S_element_big, D = D_element_big)
colnames(SD_grid_big) <- c("S", "D")
output_big <- full_join(SD_grid_big, output, by = c("S", "D"))
output_big$X <- replace(output_big$X, is.na(output_big$X), 0)
opt_launch_rate <- ipol(output$X, grid = grid_list, method = "multilinear")

# ---- Phase-plane quantities on the fine grid ----
# Calibration scale factors: s_scale_factor multiplies satellite and launch
# values, d_scale_factor multiplies debris values.
s_scale_factor <- 68289.84 / 3.728625
d_scale_factor <- 367372.3 / 4.214421

# Interpolated optimal launch rate at each node of SD_grid. apply() passes
# each row as a named numeric vector (satellites, debris).
launch_rates <- apply(SD_grid, MARGIN = 1,
                      FUN = function(node) opt_launch_rate(node[1:2]))
SD_grid_with_rates <- cbind(SD_grid, launch_rates)

# oa_S_ / oa_D_: S_() and D_() at each node under its launch rate
# (presumably next-period satellites and debris -- confirm in equations.r).
# numerical_coll_rate: L() at that state. sat_null / deb_null: one-step
# changes divided by 0.1; their zero contours are drawn in the phase panel.
fig5_dfrm_u <- data.frame(SD_grid_with_rates) %>%
  mutate(
    oa_S_ = S_(launch_rates, satellites, debris, params),
    oa_D_ = D_(launch_rates, satellites, debris, params),
    numerical_coll_rate = L(oa_S_, oa_D_, params),
    sat_null = (oa_S_ - satellites) / (0.1),
    deb_null = (oa_D_ - debris) / (0.1)
  )

# ---- Optimal trajectories "d1" and "s2" ----
# All-zero launch path of length T (not referenced elsewhere in this script).
init_launchpath <- rep(0, length.out = T)

# Simulate params$T periods under the solved policy output$X from initial
# conditions c(0, 0) ("d1") and c(2, 0) ("s2"), presumably (satellites,
# debris) -- confirm against generate_time_series().
lp_d1 <- generate_time_series(initial_condition = c(0, 0),
                              launch_policy = output$X, grid_list,
                              n_periods = params$T, params)
lp_s2 <- generate_time_series(initial_condition = c(2, 0),
                              launch_policy = output$X, grid_list,
                              n_periods = params$T, params)
colnames(lp_d1) <- c("time", "launches", "sats", "debs")
colnames(lp_s2) <- c("time", "launches", "sats", "debs")

# Final simulated state of the d1 path, c(sats, debs), used as the optimal
# steady state.
opt_ss <- with(lp_d1, c(tail(sats, 1), tail(debs, 1)))

# ---- Ray into the optimal steady state and its one-step origins ----
# The ray through opt_ss has slope m in (S, D); it starts where D reaches 0.
opt_ss_ray_start <- opt_ss - c(opt_ss[2] / m, opt_ss[2])

# A point 0.2 (in S) along the ray from its start, and that point traced one
# step backwards with physics_inverter(). Neither is used later in this script.
ray_pt_to_trace <- opt_ss_ray_start + c(0.2, m * 0.2)
ray_pt_origin <- physics_inverter(ray_pt_to_trace, params)

# length(S_element) evenly spaced points from the ray start to opt_ss,
# as a two-column matrix (S, D).
opt_ss_ray <- cbind(
  S = seq(opt_ss_ray_start[1], opt_ss[1], length.out = length(S_element)),
  D = seq(opt_ss_ray_start[2], opt_ss[2], length.out = length(S_element))
)

# Trace each ray point backwards one step with physics_inverter(). The frame
# is pre-filled with -1 and overwritten row by row.
opt_ss_ray_origin_surface <- data.frame(
  matrix(-1, nrow = nrow(opt_ss_ray), ncol = ncol(opt_ss_ray))
)
for (i in seq_len(nrow(opt_ss_ray))) {
  opt_ss_ray_origin_surface[i, ] <- physics_inverter(opt_ss_ray[i, ], params)
}
colnames(opt_ss_ray_origin_surface) <- c("sats", "debs")

opt_ss_ray_origin_surface_scaled <- mutate(
  opt_ss_ray_origin_surface,
  sats = sats * s_scale_factor,
  debs = debs * d_scale_factor
)

# Trajectory "sj1" starts from the 57th origin-surface point.
# Author's note: doesn't seem like this is exactly a one-step frontier? 57 is
# exactly the point where it gets positive; check back to ray_start_index if
# updating params.
initial_condition_sj1 <- as.numeric(unlist(opt_ss_ray_origin_surface[57, ]))
lp_sj1 <- generate_time_series(initial_condition = initial_condition_sj1,
                               launch_policy = output$X, grid_list,
                               n_periods = params$T, params)
colnames(lp_sj1) <- c("time", "launches", "sats", "debs")

# ---- Time-series data for panels B-E ----
# Rescale a trajectory from model units to counts and append `suffix` to every
# column except `time`. The tagging is done explicitly because dplyr's join
# `suffix` argument only renames non-key columns whose names collide. Joining
# an untagged sj1 trajectory would leave its columns as plain `launches`,
# `sats`, ... instead of the `*_sj1` names the plotting code reads.
scale_trajectory <- function(trajectory, suffix) {
  as_tibble(trajectory) %>%
    mutate(
      launches = launches * s_scale_factor,
      satellites = sats * s_scale_factor,
      debris = debs * d_scale_factor
    ) %>%
    rename_with(function(nm) paste0(nm, suffix), -time) %>%
    as.data.frame()
}

# One row per period with *_d1, *_s2 and *_sj1 columns.
fig5_time_dfrm <- scale_trajectory(lp_d1, "_d1") %>%
  left_join(scale_trajectory(lp_s2, "_s2"), by = "time") %>%
  left_join(scale_trajectory(lp_sj1, "_sj1"), by = "time")

# ---- Pad shorter frames to the grid's row count ----
# Return `dfrm` followed by NA rows so the result has `n_rows` rows and the
# same column names. This lets trajectories sit alongside the phase-plane grid
# data in one data frame. Only the rows of `dfrm` itself are filled; nothing
# is recycled.
pad_with_na_rows <- function(dfrm, n_rows) {
  padded <- as.data.frame(matrix(NA, nrow = n_rows, ncol = ncol(dfrm)))
  padded[seq_len(nrow(dfrm)), ] <- dfrm
  colnames(padded) <- colnames(dfrm)
  padded
}

lp_d1_rep <- pad_with_na_rows(lp_d1, nrow(fig5_dfrm_u))
lp_s2_rep <- pad_with_na_rows(lp_s2, nrow(fig5_dfrm_u))
lp_sj1_rep <- pad_with_na_rows(lp_sj1, nrow(fig5_dfrm_u))

# Like the trajectories, only the surface's own rows are filled; the rest stay
# NA instead of repeating the surface down the whole grid.
opt_ss_ray_origin_surface_scaled_rep <-
  pad_with_na_rows(opt_ss_ray_origin_surface_scaled, nrow(fig5_dfrm_u))

# ---- Combine grid data and trajectories ----
# data.frame() makes repeated names unique positionally. Columns from
# lp_d1_rep keep their names (time, launches, sats, debs), lp_s2_rep's get a
# ".1" suffix and lp_sj1_rep's a ".2" suffix. Map each back to an explicit
# trajectory prefix, then drop the de-duplicated originals, including the
# redundant time.1/time.2 copies.
fig5_dfrm <- data.frame(fig5_dfrm_u, lp_d1_rep, lp_s2_rep, lp_sj1_rep) %>%
  mutate(
    time = time.1,
    s2.launches = launches.1,
    d1.launches = launches,
    sj1.launches = launches.2,
    s2.sats = sats.1,
    d1.sats = sats,
    sj1.sats = sats.2,
    s2.debs = debs.1,
    d1.debs = debs,
    sj1.debs = debs.2
  ) %>%
  select(-time.1, -time.2, -launches, -launches.1, -launches.2,
         -sats, -sats.1, -sats.2, -debs, -debs.1, -debs.2)

head(fig5_dfrm)

# ---- Rescale grid and trajectory columns to counts ----
# Attach the padded origin surface as a data-frame column `opt_ss_surf`.
# Multiply satellite- and launch-denominated columns by s_scale_factor and
# debris-denominated columns by d_scale_factor.
fig5_dfrm <- fig5_dfrm %>%
  mutate(
    opt_ss_surf = opt_ss_ray_origin_surface_scaled_rep,
    across(c(launch_rates, satellites,
             s2.launches, d1.launches, sj1.launches,
             s2.sats, d1.sats, sj1.sats, oa_S_),
           function(col) col * s_scale_factor),
    across(c(debris, s2.debs, d1.debs, sj1.debs, oa_D_),
           function(col) col * d_scale_factor)
  )

head(fig5_dfrm)

# ---- Start of the dotted ray in the phase panel ----
# Among origin-surface rows with exactly zero debris (column 2), pick the one
# with the most satellites (column 1). The result is a row index of
# opt_ss_surf, so it can be used directly as opt_ss_surf[ray_start_idx, ].
# Selecting via which() keeps it a valid row index even when the zero-debris
# rows are not the first rows. NA rows (padding) are skipped by which().
ray_surf <- fig5_dfrm$opt_ss_surf
zero_debris_rows <- which(ray_surf[, 2] == 0)
ray_start_idx <- zero_debris_rows[which.max(ray_surf[zero_debris_rows, 1])]

fig5_dfrm$s2.sats[1:200]
fig5_dfrm$s2.launches[1:200]

# ---- Panels B-E: time paths of the three optimal trajectories ----
# Each panel draws trajectories s2 (viridis(9)[3]), d1 (viridis(9)[7]) and
# sj1 (viridis(9)[5]) over the first 50 rows of fig5_time_dfrm.

# Dashed trajectory line in colour viridis(9)[colour_idx].
trajectory_line <- function(mapping, colour_idx) {
  geom_line(mapping, linetype = "dashed", color = viridis(9)[colour_idx],
            size = 1)
}

# Shared axis lines, labels and theme. Axis lines through the origin are
# drawn with geom_hline()/geom_vline(). Fixed segments would stretch every
# panel: one running to x = 100 and one to y = 5 * d_scale_factor, a
# debris-count height unrelated to launches, collision risk or satellites.
time_panel_layers <- function(y_label) {
  list(
    geom_hline(yintercept = 0, size = 0.15),
    geom_vline(xintercept = 0, size = 0.15),
    ylab(y_label),
    xlab("Time"),
    theme_bw(),
    theme(text = element_text(family = "Arial", size = 15))
  )
}

time_panel_data <- fig5_time_dfrm[1:50, ]

figure_4time_a <- ggplot(data = time_panel_data, aes(x = time)) +
  trajectory_line(aes(y = launches_s2), 3) +
  trajectory_line(aes(y = launches_d1), 7) +
  trajectory_line(aes(y = launches_sj1), 5) +
  time_panel_layers("Launches")

figure_4time_b <- ggplot(data = time_panel_data, aes(x = time)) +
  trajectory_line(aes(y = L(sats_s2, debs_s2, params)), 3) +
  trajectory_line(aes(y = L(sats_d1, debs_d1, params)), 7) +
  trajectory_line(aes(y = L(sats_sj1, debs_sj1, params)), 5) +
  time_panel_layers("Collision risk")

figure_4time_c <- ggplot(data = time_panel_data, aes(x = time)) +
  trajectory_line(aes(y = satellites_s2), 3) +
  trajectory_line(aes(y = satellites_d1), 7) +
  trajectory_line(aes(y = satellites_sj1), 5) +
  time_panel_layers("Satellites")

figure_4time_d <- ggplot(data = time_panel_data, aes(x = time)) +
  trajectory_line(aes(y = debris_s2), 3) +
  trajectory_line(aes(y = debris_d1), 7) +
  trajectory_line(aes(y = debris_sj1), 5) +
  time_panel_layers("Debris")

# ---- Panel A: optimal phase diagram ----
# Layers with fixed coordinates use annotate() or geom_hline()/geom_vline():
# the ray, the steady-state point, the labels and the flow arrows. A
# geom_segment()/geom_point()/geom_text() with constant parameters inherits
# the plot's full fig5_dfrm data and draws one copy per row (40,000 here).
# Axis lines use hline/vline because segments ending at 5 * scale factor lie
# outside the c(0, 4 * scale factor) scale limits and are dropped.

# Flow arrow of length `len` (model units) starting at (debris, satellites) =
# (deb, sat) and pointing along (dx, dy); converted to counts.
phase_flow_arrow <- function(deb, sat, dx, dy, len = 0.15) {
  annotate("segment",
           x = deb * d_scale_factor,
           y = sat * s_scale_factor,
           xend = (deb + dx * len) * d_scale_factor,
           yend = (sat + dy * len) * s_scale_factor,
           arrow = arrow(length = unit(0.25, "cm")))
}

big_phase_panel <- ggplot(data = fig5_dfrm) +
  geom_hline(yintercept = 0, size = 0.15) +
  geom_vline(xintercept = 0, size = 0.15) +
  # Optimal trajectories d1, s2 and sj1 (first 50 periods).
  geom_path(data = fig5_dfrm[1:50, ], aes(x = d1.debs, y = d1.sats),
            linetype = "dashed", color = viridis(9)[7], size = 1.5) +
  geom_path(data = fig5_dfrm[1:50, ], aes(x = s2.debs, y = s2.sats),
            linetype = "dashed", color = viridis(9)[3], size = 1.5) +
  geom_path(data = fig5_dfrm[1:50, ], aes(x = sj1.debs, y = sj1.sats),
            linetype = "dashed", color = viridis(9)[5], size = 1.5) +
  # Dotted ray from its start on the origin surface to the steady state.
  annotate("segment",
           x = fig5_dfrm$opt_ss_surf[ray_start_idx, 2],
           y = fig5_dfrm$opt_ss_surf[ray_start_idx, 1],
           xend = opt_ss[2] * d_scale_factor,
           yend = opt_ss[1] * s_scale_factor,
           size = 1, linetype = "dotted") +
  # Trajectory labels.
  annotate(geom = "text", x = fig5_dfrm$sj1.debs[3],
           y = 0.8 * fig5_dfrm$sj1.sats[3], label = "3",
           hjust = -1, vjust = 1.5, size = 7, color = viridis(9)[5]) +
  annotate(geom = "text", x = fig5_dfrm$s2.debs[3],
           y = fig5_dfrm$s2.sats[3], label = "2",
           hjust = -1, vjust = 1.5, size = 7, color = viridis(9)[3]) +
  annotate(geom = "text", x = 1 * d_scale_factor, y = 0.6 * s_scale_factor,
           label = "1", hjust = 0, vjust = 1.5, size = 7,
           color = viridis(9)[7]) +
  ylab("Satellites") + xlab("Debris") +
  # Zero contours of sat_null (blue, "Sn") and deb_null (red, "Dn").
  stat_contour(aes(x = debris, y = satellites, z = sat_null,
                   group = ..level..),
               size = 1, breaks = c(0), color = "#313695", alpha = 0.5) +
  stat_contour(aes(x = debris, y = satellites, z = deb_null,
                   group = ..level..),
               size = 1, breaks = c(0), color = "#A50026", alpha = 0.5) +
  # Optimal steady state and text labels.
  annotate("point", x = opt_ss[2] * d_scale_factor,
           y = opt_ss[1] * s_scale_factor, size = 5) +
  annotate("text", label = "Dn", x = 2.75 * d_scale_factor,
           y = 3.75 * s_scale_factor, color = "#A50026", size = 7) +
  annotate("text", label = "Sn", x = 3.6 * d_scale_factor,
           y = 0.9 * s_scale_factor, color = "#313695", size = 7) +
  annotate("text", label = "Opt", x = 0.5 * d_scale_factor,
           y = 1.09 * fig5_dfrm[100, "d1.sats"], color = "black", size = 7) +
  # Flow-direction arrows near the four corners of the plotted region.
  phase_flow_arrow(0.5, 0.5, 1, 0) +
  phase_flow_arrow(0.5, 0.5, 0, 1) +
  phase_flow_arrow(3.5, 0.5, -1, 0) +
  phase_flow_arrow(3.5, 0.5, 0, 1) +
  phase_flow_arrow(3.5, 3.75, -1, 0) +
  phase_flow_arrow(3.5, 3.75, 0, -1) +
  phase_flow_arrow(0.5, 3.75, 1, 0) +
  phase_flow_arrow(0.5, 3.75, 0, -1) +
  scale_x_continuous(labels = comma, limits = c(0, 4 * d_scale_factor)) +
  scale_y_continuous(labels = comma, limits = c(0, 4 * s_scale_factor)) +
  theme_bw() +
  theme(text = element_text(family = "Arial", size = 15),
        panel.grid.minor = element_blank())

# ---- Assemble and save figure 7 ----
# Panel A is the phase diagram with coord_flip(), so satellites run along the
# horizontal axis. Panels B-E stack the four time-series plots vertically.
# NOTE(review): axis = "1" looks like a typo for "l" (cowplot documents
# "none", "l", "r", "t", "b", ...) -- confirm the intended alignment.
big_phase_panel_label <- plot_grid(
  big_phase_panel + coord_flip(),
  align = "h", axis = "1", nrow = 1, ncol = 1,
  rel_widths = c(1 / 2, 1 / 2), labels = c("A")
)

fig4_time_panel <- plot_grid(
  figure_4time_a, figure_4time_b, figure_4time_c, figure_4time_d,
  align = "v", labels = c("B", "C", "D", "E"), nrow = 4
)

figure_7_plot <- plot_grid(
  big_phase_panel_label, fig4_time_panel,
  align = "h", axis = "1", nrow = 1, ncol = 2,
  rel_widths = c(1 / 2, 1 / 2), label_size = 15
)

ggsave(
  filename = "../../images/figure-7--optimal-phase-trajectories.png",
  plot = figure_7_plot,
  width = 16, height = 8, units = "in", dpi = 300
)
